import torch
import numpy as np
import torch.nn as nn

class pAUC_CVaR(nn.Module):
    '''
    Partial-AUC loss that keeps, for every training positive, a scalar
    threshold (``lambda_pos``) and only counts positive/negative pairs whose
    surrogate loss exceeds that threshold.

    The per-positive thresholds are persistent state: they are updated in
    place on every call to ``forward``.
    '''

    def __init__(self, pos_length, num_neg, threshold=1.0, gamma=0.2, beta=0.9, loss_type = 'sh'):
        '''
        param
        pos_length: number of positive examples for the training data
        num_neg: number of negative samples for each mini-batch
        threshold: margin for basic AUC loss
        gamma: FPR upper bound for pAUC used for SOTA
        beta: stepsize for CVaR regularization term
        loss type: basic AUC loss to apply ('sh' = squared hinge, 'sq' = square).
        '''
        super(pAUC_CVaR, self).__init__()
        # Snap gamma to a multiple of 1/num_neg so gamma*num_neg is a whole count.
        self.gamma = round(gamma*num_neg)/num_neg
        self.beta = beta
        self.num_neg = num_neg
        self.pos_length = pos_length
        # One threshold per training positive, shape (pos_length, 1).
        # Prefer CUDA (the historical behaviour) but fall back to CPU instead
        # of crashing on machines without a GPU.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.lambda_pos = torch.zeros(pos_length, 1, device=device)
        self.threshold = threshold
        self.loss_type = loss_type
        print('The loss type is :', self.loss_type)

    def update_beta(self, decay_factor):
        '''Divide the threshold step size ``beta`` by ``decay_factor``.'''
        self.beta = self.beta/decay_factor

    def _surrogate(self, score_gap):
        '''
        Element-wise pairwise surrogate loss for score gaps ``f_p - f_n``.

        Raises ValueError for an unsupported ``loss_type`` (previously this
        surfaced as an UnboundLocalError inside ``forward``).
        '''
        violation = self.threshold - score_gap
        if self.loss_type == 'sh':
            return torch.max(violation, torch.zeros_like(violation)) ** 2
        if self.loss_type == 'sq':
            return violation ** 2
        raise ValueError("loss_type must be 'sh' or 'sq', got %r" % (self.loss_type,))

    def forward(self, y_pred, y_true, index_p, index_n):
        '''
        param
        y_pred: predicted scores for the mini-batch
        y_true: labels; entries equal to 1 are positives, equal to 0 negatives
        index_p: rows of ``lambda_pos`` belonging to the batch positives
                 (presumably in the same order as the positives in y_pred --
                 confirm against the caller)
        index_n: unused; kept for interface compatibility

        return: scalar loss tensor
        '''
        # Keep the state on the same device as the incoming scores.
        if self.lambda_pos.device != y_pred.device:
            self.lambda_pos = self.lambda_pos.to(y_pred.device)

        f_ps = y_pred[y_true == 1].view(-1, 1)
        f_ns = y_pred[y_true == 0].view(-1).unsqueeze(0)

        # (num_pos, num_neg) matrix of pairwise losses via broadcasting.
        loss = self._surrogate(f_ps - f_ns)

        # Selection mask and threshold update are bookkeeping only: no gradient.
        with torch.no_grad():
            current = self.lambda_pos[index_p]
            p = loss > current
            self.lambda_pos[index_p] = current-self.beta/self.pos_length*(1 - p.sum(dim=1, keepdim=True)/(self.gamma*self.num_neg))

        return torch.mean(p * loss) / self.gamma

class pAUC_KL(nn.Module):
    '''
    Partial-AUC loss that re-weights every positive/negative pair by
    ``exp(loss / Lambda)`` normalised with a per-positive moving average
    (``u_pos``) of that quantity.

    ``u_pos`` is persistent state updated in place on every ``forward`` call.
    '''

    def __init__(self, pos_length, threshold=1.0, beta=0.9, Lambda=1.0, loss_type = 'sh'):
        '''
        param
        pos_length: number of positive examples for the training data
        threshold: margin for basic AUC loss
        beta: moving average parameter
        Lambda: robust regularization parameter
        loss type: basic AUC loss to apply ('sh' = squared hinge, 'sq' = square).
        '''
        super(pAUC_KL, self).__init__()
        self.beta = beta
        self.Lambda = Lambda
        # One moving average per training positive, shape (pos_length, 1).
        # Prefer CUDA (the historical behaviour) but fall back to CPU instead
        # of crashing on machines without a GPU.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.u_pos = torch.zeros(pos_length, 1, device=device)
        self.threshold = threshold
        self.loss_type = loss_type
        print('The loss type is :', self.loss_type)

    def update_beta(self, decay_factor):
        '''Divide the moving-average weight ``beta`` by ``decay_factor``.'''
        self.beta = self.beta/decay_factor

    def _surrogate(self, score_gap):
        '''
        Element-wise pairwise surrogate loss for score gaps ``f_p - f_n``.

        Raises ValueError for an unsupported ``loss_type`` (previously this
        surfaced as an UnboundLocalError inside ``forward``).
        '''
        violation = self.threshold - score_gap
        if self.loss_type == 'sh':
            return torch.max(violation, torch.zeros_like(violation)) ** 2
        if self.loss_type == 'sq':
            return violation ** 2
        raise ValueError("loss_type must be 'sh' or 'sq', got %r" % (self.loss_type,))

    def forward(self, y_pred, y_true, index_p, index_n):
        '''
        param
        y_pred: predicted scores for the mini-batch
        y_true: labels; entries equal to 1 are positives, equal to 0 negatives
        index_p: rows of ``u_pos`` belonging to the batch positives
                 (presumably in the same order as the positives in y_pred --
                 confirm against the caller)
        index_n: unused; kept for interface compatibility

        return: scalar loss tensor
        '''
        # Keep the state on the same device as the incoming scores.
        if self.u_pos.device != y_pred.device:
            self.u_pos = self.u_pos.to(y_pred.device)

        f_ps = y_pred[y_true == 1].view(-1, 1)
        f_ns = y_pred[y_true == 0].view(-1).unsqueeze(0)

        # (num_pos, num_neg) matrix of pairwise losses via broadcasting.
        loss = self._surrogate(f_ps - f_ns)

        # The weights are treated as constants. Computing them (and the state
        # update) without autograd keeps u_pos from accumulating the graph of
        # every past iteration.
        with torch.no_grad():
            exp_loss = torch.exp(loss/self.Lambda)
            self.u_pos[index_p] = (1 - self.beta) * self.u_pos[index_p] + self.beta * (exp_loss.mean(1, keepdim=True))
            p = exp_loss/self.u_pos[index_p]

        return torch.mean(p * loss)


class pAUC_mini(nn.Module):
    '''
    Mini-batch heuristic partial-AUC loss: every positive in the batch is
    compared only against the highest-scoring negatives of the batch.
    '''

    def __init__(self, num_neg, threshold=1.0, gamma=0.2, loss_type = 'sh'):
        '''
        param
        num_neg: number of negative samples for each mini-batch
        threshold: margin for basic AUC loss
        gamma: FPR upper bound for pAUC for the mini batch heuristic loss
        loss type: basic AUC loss to apply ('sh' = squared hinge, 'sq' = square).
        '''
        super(pAUC_mini, self).__init__()
        # Snap gamma to a multiple of 1/num_neg so gamma*num_neg is a whole count.
        self.gamma = round(gamma*num_neg)/num_neg
        self.num_neg = num_neg
        self.threshold = threshold
        self.loss_type = loss_type
        print('Num negative :', self.num_neg)
        print('The loss type is :', self.loss_type)

    def _surrogate(self, score_gap):
        '''
        Element-wise pairwise surrogate loss for score gaps ``f_p - f_n``.

        Raises ValueError for an unsupported ``loss_type`` (previously this
        surfaced as an UnboundLocalError inside ``forward``).
        '''
        violation = self.threshold - score_gap
        if self.loss_type == 'sh':
            return torch.max(violation, torch.zeros_like(violation)) ** 2
        if self.loss_type == 'sq':
            return violation ** 2
        raise ValueError("loss_type must be 'sh' or 'sq', got %r" % (self.loss_type,))

    def forward(self, y_pred, y_true):
        '''
        param
        y_pred: predicted scores for the mini-batch
        y_true: labels; entries equal to 1 are positives, equal to 0 negatives

        return: scalar loss tensor (mean pairwise loss over all positives
                versus the selected top negatives)
        '''
        f_ps = y_pred[y_true == 1].view(-1, 1)
        f_ns = y_pred[y_true == 0].view(-1)

        # round(), not int(): (k/num_neg)*num_neg can land just below k in
        # floating point (e.g. 49*(1/49) == 0.9999999999999999), and truncation
        # would then silently drop a negative -- or all of them.
        # Also never ask topk for more negatives than the batch contains.
        top_k = min(round(self.num_neg*self.gamma), f_ns.numel())
        hard_negs = torch.topk(f_ns, top_k, sorted = False)[0]

        # (num_pos, top_k) matrix of pairwise losses via broadcasting.
        loss = self._surrogate(f_ps - hard_negs.unsqueeze(0))

        return torch.mean(loss)


class P_PUSH(nn.Module):
    '''
    P-norm-push style ranking loss: each positive's pairwise losses are
    weighted by ``poly * u ** (poly - 1)``, where ``u`` is a per-positive
    moving average (``u_pos``) of its mean pairwise loss.

    ``u_pos`` is persistent state updated in place on every ``forward`` call.
    '''

    def __init__(self, pos_length, threshold=1.0, beta=0.9, poly=2, loss_type = 'sh'):
        '''
        param
        pos_length: number of positive examples for the training data
        threshold: margin for basic AUC loss
        beta: moving average parameter
        poly: polynomial (p) parameter for p-norm push
        loss type: basic AUC loss to apply ('sh' = squared hinge, 'sq' = square).
        '''
        super(P_PUSH, self).__init__()

        self.poly = poly
        self.beta = beta
        # One moving average per training positive, shape (pos_length, 1).
        # Prefer CUDA (the historical behaviour) but fall back to CPU instead
        # of crashing on machines without a GPU.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.u_pos = torch.zeros(pos_length, 1, device=device)
        self.threshold = threshold
        self.loss_type = loss_type
        print('The loss type is :', self.loss_type)

    def update_beta(self, decay_factor):
        '''Divide the moving-average weight ``beta`` by ``decay_factor``.'''
        self.beta = self.beta/decay_factor

    def _surrogate(self, score_gap):
        '''
        Element-wise pairwise surrogate loss for score gaps ``f_p - f_n``.

        Raises ValueError for an unsupported ``loss_type`` (previously this
        surfaced as an UnboundLocalError inside ``forward``).
        '''
        violation = self.threshold - score_gap
        if self.loss_type == 'sh':
            return torch.max(violation, torch.zeros_like(violation)) ** 2
        if self.loss_type == 'sq':
            return violation ** 2
        raise ValueError("loss_type must be 'sh' or 'sq', got %r" % (self.loss_type,))

    def forward(self, y_pred, y_true, index_p, index_n):
        '''
        param
        y_pred: predicted scores for the mini-batch
        y_true: labels; entries equal to 1 are positives, equal to 0 negatives
        index_p: rows of ``u_pos`` belonging to the batch positives
                 (presumably in the same order as the positives in y_pred --
                 confirm against the caller)
        index_n: unused; kept for interface compatibility

        return: scalar loss tensor
        '''
        # Keep the state on the same device as the incoming scores.
        if self.u_pos.device != y_pred.device:
            self.u_pos = self.u_pos.to(y_pred.device)

        f_ps = y_pred[y_true == 1].reshape(-1, 1)
        f_ns = y_pred[y_true == 0].reshape(-1).unsqueeze(0)

        # (num_pos, num_neg) matrix of pairwise losses via broadcasting.
        loss = self._surrogate(f_ps - f_ns)

        # The weights are treated as constants. Updating the state without
        # autograd keeps u_pos from holding on to the graph of past iterations.
        with torch.no_grad():
            self.u_pos[index_p] = (1 - self.beta) * self.u_pos[index_p] + self.beta * (loss.mean(1, keepdim=True))
            p = self.poly*(self.u_pos[index_p]**(self.poly-1))

        return torch.mean(p * loss)


class pAUC_KL_two(nn.Module):
    '''
    Two-level KL-weighted partial-AUC loss.

    Keeps a per-positive moving average ``u_pos`` of ``exp(loss / Lambda)``
    and a scalar moving average ``w`` of ``u_pos ** (Lambda / tau)``; each
    pair is weighted by ``u ** (Lambda / tau - 1) * exp(loss / Lambda) / w``.

    Both ``u_pos`` and ``w`` are persistent state updated on every
    ``forward`` call.
    '''

    def __init__(self, pos_length, Lambda=1.0, tau=1.0, threshold=1.0, beta_1=0.9, beta_2=0.9, loss_type = 'sh'):
        '''
        param
        pos_length: number of positive examples for the training data
        Lambda: robust regularization parameter on negative sample
        tau: robust regularization parameter on positive sample
        threshold: margin for basic AUC loss
        beta_1: moving average parameter on negative sample
        beta_2: moving average parameter on positive sample
        loss type: basic AUC loss to apply ('sh' = squared hinge, 'sq' = square).
        '''
        super(pAUC_KL_two, self).__init__()
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.Lambda = Lambda
        self.tau = tau
        # One moving average per training positive, shape (pos_length, 1).
        # Prefer CUDA (the historical behaviour) but fall back to CPU instead
        # of crashing on machines without a GPU.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.u_pos = torch.zeros(pos_length, 1, device=device)
        # Starts as a Python float; becomes a 0-dim tensor after the first forward.
        self.w = 0.0
        self.threshold = threshold
        self.loss_type = loss_type
        print('The loss type is :', self.loss_type)

    def update_beta(self, decay_factor):
        '''Divide both moving-average weights by ``decay_factor``.'''
        self.beta_1 = self.beta_1/decay_factor
        self.beta_2 = self.beta_2/decay_factor

    def _surrogate(self, score_gap):
        '''
        Element-wise pairwise surrogate loss for score gaps ``f_p - f_n``.

        Raises ValueError for an unsupported ``loss_type`` (previously this
        surfaced as an UnboundLocalError inside ``forward``).
        '''
        violation = self.threshold - score_gap
        if self.loss_type == 'sh':
            return torch.max(violation, torch.zeros_like(violation)) ** 2
        if self.loss_type == 'sq':
            return violation ** 2
        raise ValueError("loss_type must be 'sh' or 'sq', got %r" % (self.loss_type,))

    def forward(self, y_pred, y_true, index_p, index_n):
        '''
        param
        y_pred: predicted scores for the mini-batch
        y_true: labels; entries equal to 1 are positives, equal to 0 negatives
        index_p: rows of ``u_pos`` belonging to the batch positives
                 (presumably in the same order as the positives in y_pred --
                 confirm against the caller)
        index_n: unused; kept for interface compatibility

        return: scalar loss tensor
        '''
        # Keep the state on the same device as the incoming scores.
        if self.u_pos.device != y_pred.device:
            self.u_pos = self.u_pos.to(y_pred.device)
        if torch.is_tensor(self.w) and self.w.device != y_pred.device:
            self.w = self.w.to(y_pred.device)

        f_ps = y_pred[y_true == 1].reshape(-1, 1)
        f_ns = y_pred[y_true == 0].reshape(-1).unsqueeze(0)

        # (num_pos, num_neg) matrix of pairwise losses via broadcasting.
        loss = self._surrogate(f_ps - f_ns)

        # The weights are treated as constants. Updating u_pos and w without
        # autograd keeps them from chaining together the graphs of all past
        # iterations.
        with torch.no_grad():
            exp_loss = torch.exp(loss/self.Lambda)
            self.u_pos[index_p] = (1 - self.beta_1) * self.u_pos[index_p] + self.beta_1 * (exp_loss.mean(1, keepdim=True))
            u_batch = self.u_pos[index_p]
            self.w = (1 - self.beta_2) * self.w + self.beta_2 * (torch.pow(u_batch, self.Lambda/self.tau).mean())
            p = torch.pow(u_batch, self.Lambda/self.tau - 1) * exp_loss/self.w

        return torch.mean(p * loss)


class pAUC_mini_two(nn.Module):
    '''
    Mini-batch heuristic for a two-sided partial AUC: only the lowest-scoring
    positives are compared against the highest-scoring negatives of the batch.
    '''

    def __init__(self, num_pos, num_neg, threshold=1.0, gamma=0.2, loss_type = 'sh'):
        '''
        param
        num_pos: number of positive samples for each mini-batch
        num_neg: number of negative samples for each mini-batch
        threshold: margin for basic AUC loss
        gamma: fraction of positives (lowest scores) and of negatives
               (highest scores) kept for the loss
        loss type: basic AUC loss to apply ('sh' or 'sq').
        '''
        super(pAUC_mini_two, self).__init__()
        self.loss_type = loss_type
        self.threshold = threshold
        self.num_neg = num_neg
        self.num_pos = num_pos
        self.gamma = gamma
        print('The loss type is :', self.loss_type)

    def forward(self, y_pred, y_true):
        '''
        param
        y_pred: predicted scores for the mini-batch
        y_true: labels; entries equal to 1 are positives, equal to 0 negatives

        return: scalar tensor, the mean pairwise loss over the selected
                positives versus the selected negatives
        '''
        pos_scores = y_pred[y_true == 1].view(-1)
        neg_scores = y_pred[y_true == 0].view(-1)

        keep_pos = round(self.gamma*self.num_pos)
        keep_neg = round(self.gamma*self.num_neg)

        # Hardest examples on each side: weakest positives, strongest negatives.
        hard_pos = torch.topk(pos_scores, keep_pos, largest=False, sorted = False)[0]
        hard_neg = torch.topk(neg_scores, keep_neg, largest=True, sorted = False)[0]

        # Broadcasting yields the (keep_pos, keep_neg) matrix of margin violations.
        violation = self.threshold - (hard_pos.view(-1, 1) - hard_neg.unsqueeze(0))

        if self.loss_type == 'sh':
            pair_loss = torch.max(violation, torch.zeros_like(violation)) ** 2
        elif self.loss_type == 'sq':
            pair_loss = violation ** 2

        return torch.mean(pair_loss)



